import torch

import model
from dataset import CelestialDataset
import torch.utils.data as tud
from torchvision import models
from torch import nn

from main import BATCH_SIZE, SHUFFLE, NUM_WORKERS

if __name__ == "__main__":
    # Evaluate a fine-tuned GoogLeNet checkpoint on the Celestial test split.
    checkpoint_path = './googlenet.pth'

    celestial_data = CelestialDataset(action='test')
    # NOTE(review): shuffling is not needed for evaluation; SHUFFLE is reused
    # from the training config here — confirm model_test does not rely on it.
    testing_celestial_dl = tud.DataLoader(dataset=celestial_data,
                                          batch_size=BATCH_SIZE,
                                          shuffle=SHUFFLE,
                                          num_workers=NUM_WORKERS)

    # NOTE(review): pretrained=True downloads ImageNet weights that are
    # immediately overwritten below; it is kept because it presumably also
    # fixes the architecture the checkpoint was saved with — confirm before
    # switching to the `weights=` API.
    model_googlenet = models.googlenet(pretrained=True).to(model.device)
    # map_location lets a checkpoint saved on one device (e.g. GPU) load on
    # whatever device this run uses (e.g. CPU-only machines).
    state_dict = torch.load(checkpoint_path, map_location=model.device)
    model_googlenet.load_state_dict(state_dict)
    # Put the network in inference mode (disables dropout etc.); harmless if
    # model_test also does this.
    model_googlenet.eval()

    # Cross-entropy loss, used by model_test to report test loss.
    criterion = nn.CrossEntropyLoss()

    # Run evaluation over the whole test loader.
    model.model_test(model_googlenet, criterion, testing_celestial_dl)
